Mouse MOp (MERFISH)¶

Importing¶

In [1]:
import enclus
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
# Silence all warnings to keep the rendered notebook clean.
# NOTE(review): with no category/module this also hides DeprecationWarning and
# FutureWarning from scipy, scanpy, seaborn, etc.; consider narrowing the filter.
warnings.filterwarnings('ignore')

Preprocessing scRNA-ref¶

We use a mouse MOp snRNA-seq dataset as the reference. The raw dataset in h5ad format (mop_sn_tutorial.h5ad) is available here. The pre-processed scRNA-seq reference (Ckpts_scRefs/MOp/Ref_snRNA_mop_qc3_2Kgenes.h5ad) and other relevant materials used in the following example are available here.

In [2]:
# obs column that holds the reference cell-type annotation
cell_type_column = 'subclass_label'
# snRNA-seq reference: 20 cell types, 13516 cells x 21158 genes
sc_data = sc.read('datasets/Ref_snRNA_mop_qc3_2Kgenes.h5ad')
sc_data.X.max()
Out[2]:
np.float32(7.5840874)
In [3]:
# Export the MERFISH query and the snRNA-seq reference as CSV tables:
# MERFISH counts + x/y coordinates, and reference counts + per-cell metadata
# (barcode, cell-type cluster, nUMI).
from scipy.sparse import issparse

# Load the MERFISH spatial data and the snRNA-seq reference
st_data = sc.read('datasets/MERFISH_mop.h5ad')
sc_data = sc.read('datasets/Ref_snRNA_mop_qc3_2Kgenes.h5ad')

st_data.obs_names_make_unique()
st_data.var_names_make_unique()

# Harmonise column names: coordinates -> 'x'/'y', cell-type labels -> 'cell_type'
st_data.obs = st_data.obs.rename(columns={'X': 'x', 'Y': 'y'})
st_data.obs = st_data.obs.rename(columns={'subclass': 'cell_type'})
sc_data.obs = sc_data.obs.rename(columns={'subclass_label': 'cell_type'})

# Store the x/y coordinates as the spatial embedding
spatial_data = st_data.obs[['x', 'y']].values
st_data.obsm['spatial'] = spatial_data

# Densify both expression matrices: issparse() covers every scipy sparse
# format, and pd.DataFrame below needs dense arrays.
if issparse(sc_data.X):
    sc_data.X = sc_data.X.toarray()
if issparse(st_data.X):
    st_data.X = st_data.X.toarray()

# Merge 'L4/5 IT' into 'L5 IT' in the spatial labels. Spatial types with no
# reference counterpart are removed by the overlap filter below.
sp_adata_ct = np.array([label.replace('L4/5 IT', 'L5 IT') for label in st_data.obs['cell_type']])
st_data.obs['cell_type'] = sp_adata_ct

# Cell types present in both datasets (sorted so the order is reproducible)
overlap_ct = np.array(sorted(set(np.unique(st_data.obs['cell_type'])) &
                             set(np.unique(sc_data.obs['cell_type']))))

st_data = st_data[st_data.obs['cell_type'].isin(overlap_ct)].copy()
sc_data = sc_data[sc_data.obs['cell_type'].isin(overlap_ct)].copy()

merfish_data = st_data
ref_data = sc_data

# MERFISH expression (.X as loaded; not log-transformed in this cell) and coordinates
counts_merfish = merfish_data.X
coords_merfish = merfish_data.obsm['spatial']

# Total counts per cell (1-D because .X is dense here); not written out below
nUMI_merfish = counts_merfish.sum(axis=1)

# Reference expression, cell-type labels and per-cell nUMI (read from .obs)
counts_ref = ref_data.X
cell_types_ref = ref_data.obs['cell_type']
nUMI_ref = ref_data.obs['nUMI']

# Save MERFISH counts and coordinates
merfish_counts_df = pd.DataFrame(
    counts_merfish,
    index=merfish_data.obs_names,
    columns=merfish_data.var_names
)
merfish_counts_df.to_csv('MERFISH_counts.csv')

coords_df = pd.DataFrame(
    coords_merfish,
    index=merfish_data.obs_names,
    columns=['x', 'y']
)
coords_df.to_csv('MERFISH_coords.csv')

# Save reference counts and metadata (barcode, cluster, nUMI)
ref_counts_df = pd.DataFrame(
    counts_ref,
    index=ref_data.obs_names,
    columns=ref_data.var_names
)
ref_counts_df.to_csv('Ref_counts.csv')

meta_data_ref_df = pd.DataFrame({
    'barcode': ref_data.obs_names,
    'cluster': cell_types_ref,
    'nUMI': nUMI_ref
})
meta_data_ref_df.to_csv('Ref_meta_data.csv')

Preprocessing MERFISH data¶

We use the single-cell reference to identify cell-type marker genes and highly variable genes. We then remove spatial cell types that have no counterpart in the reference.

In [ ]:
# Build the training gene panel from the reference, then reload and
# log-transform the MERFISH data and keep only cell types shared by both.
# NOTE(review): relies on `sc_data` from an earlier cell, which must carry
# uns['rank_genes_groups'] and var['highly_variable'].
from scipy.sparse import issparse

sc_data.raw = sc_data.copy()
cell_type_column = 'subclass_label'

# Top 30 ranked genes of every rank_genes_groups group, deduplicated
markers_df = pd.DataFrame(sc_data.uns["rank_genes_groups"]["names"]).iloc[0:30, :]
markers = list(np.unique(markers_df.melt().value.values))
# Union with the reference's highly variable genes (HVG 1931 + cell type marker genes)
markers = list(set(sc_data.var.loc[sc_data.var['highly_variable']==1].index)|set(markers))
print(len(markers))

st_data = sc.read('./datasets/MERFISH_mop.h5ad')  # 23 cell types (5551, 254)
st_data.obs_names_make_unique()
st_data.var_names_make_unique()

# Harmonise column names: coordinates -> 'x'/'y', cell-type labels -> 'cell_type'
st_data.obs = st_data.obs.rename(columns = {'X':'x', 'Y':'y'})
st_data.obs = st_data.obs.rename(columns = {'subclass':'cell_type'})
sc_data.obs = sc_data.obs.rename(columns = {'subclass_label':'cell_type'})

# Store the x/y coordinates as the spatial embedding
spatial_data = st_data.obs[['x', 'y']].values
st_data.obsm['spatial'] = spatial_data

# log1p-transform the MERFISH expression in place
sc.pp.log1p(st_data)
print(st_data.X.max())
merfish_genes = st_data.var.index.values.tolist()

# Genes of particular interest that are always added to the panel
add_genes = 'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'.split()
markers = np.unique(markers+add_genes)

print("Markers:",len(markers))

# Densify the reference matrix (issparse() covers every scipy sparse format)
if issparse(sc_data.X):
    sc_data.X = sc_data.X.toarray()

# Merge 'L4/5 IT' into 'L5 IT'; spatial types absent from the reference are
# then dropped by the overlap filter (sorted for a reproducible order).
sp_adata_ct = np.array([label.replace('L4/5 IT', 'L5 IT') for label in st_data.obs['cell_type']])
st_data.obs['cell_type'] = sp_adata_ct
overlap_ct = np.array(sorted(set(np.unique(st_data.obs['cell_type'])) & set(np.unique(sc_data.obs['cell_type']))))

st_data = st_data[st_data.obs['cell_type'].isin(overlap_ct)].copy()
sc_data = sc_data[sc_data.obs['cell_type'].isin(overlap_ct)].copy()

# NOTE(review): the processed files loaded in the next section carry extra
# annotations (e.g. var['Marker'], var['MERFISH_gene']) that this cell does not
# create, so they were presumably produced separately; the writes stay disabled.
# st_data.write('./datasets/mop/processed_MERFISH_mop.h5ad')
# sc_data.write('./datasets/mop/processed_snRNA_mop.h5ad')
1931
5.3889804
Markers: 1938

Load processed data¶

In [3]:
# Reload the preprocessed reference (snRNA-seq) and spatial (MERFISH) AnnData objects
sc_data, st_data = (
    sc.read_h5ad('./datasets/mop/processed_snRNA_mop.h5ad'),
    sc.read_h5ad('./datasets/mop/processed_MERFISH_mop.h5ad'),
)

print(sc_data, st_data)
# Genes of particular interest; passed to the model below as sc_genes
add_genes = 'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'.split()
AnnData object with n_obs × n_vars = 13516 × 21158
    obs: 'QC', 'batch', 'class_color', 'class_id', 'class_label', 'cluster_color', 'cluster_labels', 'dataset', 'date', 'ident', 'individual', 'nCount_RNA', 'nFeature_RNA', 'nGene', 'nUMI', 'project', 'region', 'species', 'subclass_id', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_counts', 'subclass_label_R'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'Marker', 'MERFISH_gene'
    uns: 'hvg', 'neighbors', 'pca', 'rank_genes_groups', 'subclass_label_colors', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'logcounts'
    obsp: 'connectivities', 'distances' AnnData object with n_obs × n_vars = 5381 × 254
    obs: 'sample_id', 'slice_id', 'class_label', 'cell_type', 'label', 'x', 'y'
    obsm: 'spatial'
In [4]:
(sc_data.X.max(), st_data.X.max())  # value range of reference vs. spatial expression
Out[4]:
(np.float32(7.5840874), np.float32(5.3889804))

Train the SpateCV model to impute genes¶

After training, only the model's gene-imputation function is used here.

In [5]:
# Build and train the ENCLUS model (presented in this notebook as "SpateCV"),
# impute genes for the spatial data, and copy the model outputs back onto the
# notebook's AnnData objects.
# NOTE(review): long-running (the captured log shows ~2h40m before early
# stopping), and no random seed is set here, so re-runs may differ.
enclus_model = enclus.ENCLUS(spatial_data = st_data, sc_data = sc_data,
                    num_layers=3,
                    num_neurons=1024,
                    latent_dim=512,
                    k_nearest=16,
                    num_cov_genes=64,
                    num_HVG=1024,
                    sc_genes=add_genes,  # genes of interest; presumably forced into the sc gene set — confirm in ENCLUS
                    spatial_dist="pois",  # presumably a Poisson likelihood for spatial data — confirm
                    sc_dist="nb",  # presumably a negative-binomial likelihood for sc data — confirm
                    spatial_coeff=1,  
                    sc_coeff=1,   
                    kl_coeff=0.03, 
                    n_clusters=10,
                    tau=0.1, 
                    gamma=0.1,
                    adaptive_weights=True,
                    early_stopping=True,
                    patience=30,
                    num_heads=10,
                    head_dim=168,
                    distance_metric='euclidean'
                    )
#train model
# With early_stopping above, training may end before training_steps
# (the captured run stopped at step 986 of 4628).
enclus_model.train(training_steps=4628,
    batch_size=2048,
    verbose=100,
    init_lr=0.00001,
    decay_steps=4000)

# Writes 'imputation' into enclus_model.spatial_data.obsm (per the printed message)
enclus_model.impute_genes()
# Copy the latents and the imputed expression onto the notebook-level objects
st_data.obsm['enclus_latent'] = enclus_model.spatial_data.obsm['enclus_latent']
st_data.obsm['imputation'] = enclus_model.spatial_data.obsm['imputation']
sc_data.obsm['enclus_latent'] = enclus_model.sc_data.obsm['enclus_latent']
sc_data shape and st_data shape: (13516, 1938) (5381, 247)
Initializing CVAE
Finished Initializing ENCLUS
Initializing cluster centers...
 | spatial_w: 3.75 sc_w: 4.04 cov_w: 6.19 kl_w: 0.80 cluster_w: 1.05:  21%|██▏       | 986/4628 [2:40:29<9:52:47,  9.77s/it] 
Early stopping triggered

Finished imputing missing gene for spatial data! See 'imputation' in obsm of ENCLUS.spatial_data

View the results of gene imputation¶

In [6]:
st_data.obsm["imputation"]  # imputed expression: cells x genes DataFrame
Out[6]:
index 1700022I11Rik 1810046K07Rik 5730522E02Rik Acta2 Adam2 Adamts2 Adamts4 Adra1b Alk Ankfn1 ... Tal1 1700047M11Rik Tbx1 Hepacam Aplp1 Slc2a1 Flt3 Ldb2 Tnfrsf10b Gm26522
100119755510557417791056480683612014915 0.033958 0.016518 0.467514 0.055450 0.011491 0.022482 0.121522 0.276107 0.012841 0.001613 ... 0.006594 0.015793 0.006942 0.154774 2.848228 0.282262 0.036911 0.967486 0.009771 0.030900
100132293312011676013834246988384644937 0.017182 0.005288 0.038738 0.007840 0.005177 0.135361 0.324683 0.036284 0.018930 0.005991 ... 0.004786 1.213946 0.006967 1.754055 2.685546 0.339692 0.010658 0.077153 0.012750 0.003057
100141477907895285159644541629937136293 0.012541 0.033874 0.024705 0.016275 0.013458 0.007174 0.004996 0.047633 0.017084 0.101747 ... 0.013227 0.024183 0.023657 0.450264 1.117197 0.769400 0.044615 0.563372 0.010724 0.031875
100170194898756150593503685585478903524 0.025831 0.006050 0.124738 0.006358 0.005296 0.008292 0.010910 0.067310 1.180437 0.035100 ... 0.002178 0.006100 0.006054 0.055215 0.607525 0.071660 0.826093 0.144919 0.008019 0.013541
100221098919514132063709706431102588200 0.013862 0.007168 0.099484 0.007788 0.016363 0.045896 0.028138 0.792513 0.006974 0.002345 ... 0.009607 0.015191 0.023333 0.156937 1.545135 0.106255 0.033739 0.398914 0.003060 0.013868
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
99761023417729303996882966126171608555 0.021530 0.010730 0.064236 0.016746 0.065895 0.022191 0.009726 0.274346 0.006460 0.012128 ... 0.023330 0.018904 0.016180 0.116808 2.861451 0.355431 0.035919 1.282821 0.014298 0.019250
9976139147878378112755895136358908028 0.027440 0.009032 0.028702 0.003477 0.009816 0.006668 0.001906 0.101085 0.046885 0.008719 ... 0.008437 0.016010 0.005064 1.883570 0.878763 0.423456 0.006838 0.079814 0.011676 0.007929
99775860964963590854647880769830485687 0.014351 0.013585 0.039534 0.005633 0.023213 0.010681 0.005401 0.428636 0.013966 0.009738 ... 0.016675 0.009954 0.020660 0.229057 1.067936 0.623710 0.061817 0.729681 0.007573 0.012526
99800137893930745778774961383897495957 0.019881 0.025649 0.119920 0.028423 0.070167 0.079435 0.130550 1.304224 0.008272 0.002989 ... 0.007429 0.014031 0.029708 0.105470 3.212263 0.087636 0.035806 0.760579 0.007636 0.021172
99972994093375169663255672333494574767 0.017218 0.035808 0.046467 0.016681 0.023042 0.009058 0.003639 0.146057 0.024576 0.091198 ... 0.018043 0.016521 0.021321 0.333742 1.023737 0.742046 0.043975 0.830922 0.010382 0.032604

5381 rows × 1938 columns

View the latent representations of scRNA-seq and ST¶

Defining cell type color palette

In [ ]:
import umap.umap_ as umap

# Joint 2-D UMAP of the ENCLUS latent space: spatial cells first, then reference cells.
# A fixed seed makes the embedding reproducible on re-run
# (umap-learn then runs single-threaded).
UMAP_SEED = 0
fit = umap.UMAP(
    n_neighbors = 50,
    min_dist = 0.5,
    n_components = 2,
    random_state = UMAP_SEED,
)

latent_umap = fit.fit_transform(np.concatenate([st_data.obsm['enclus_latent'], sc_data.obsm['enclus_latent']], axis = 0))

st_data.obsm['latent_umap'] = latent_umap[:st_data.shape[0]]
sc_data.obsm['latent_umap'] = latent_umap[st_data.shape[0]:]

# Shared axis limits for both panels: 0.1th-99.9th percentile of the joint
# embedding (same rows as latent_umap), padded by `delta`.
lim_arr = latent_umap
delta = 1
pre = 0.1
xmin = np.percentile(lim_arr[:, 0], pre) - delta
xmax = np.percentile(lim_arr[:, 0], 100 - pre) + delta
ymin = np.percentile(lim_arr[:, 1], pre) - delta
ymax = np.percentile(lim_arr[:, 1], 100 - pre) + delta

# Cell-type palette shared by the plots below ('L4/5 IT' reuses the 'L5 IT' colour)
color_dict = {'Astro': '#1f77b4',
 'Endo': '#aec7e8',
 'L2/3 IT': '#ff7f0e',
 'L5 ET': '#ffbb78',
 'L5 IT': '#2ca02c',
 'L5/6 NP': '#98df8a',
 'L6 CT': '#d62728',
 'L6 IT': '#ff9896',
 'L6b': '#9467bd',
 'Lamp5': '#c5b0d5',
 'Micro': '#8c564b',
 'Oligo': '#c49c94',
 'OPC': '#e377c2',
 'Peri': '#f7b6d2',
 'Pvalb': '#7f7f7f',
 'PVM': '#a3a2a2',
 'Sncg': '#bcbd22',
 'Sst': '#dbdb8d',
 'Vip': '#17becf',
 'VLMC': '#9edae5',
 'None': '#dbd9d9',
 'SMC': '#B87BCE',
 'L6 IT Car3': '#82A8CE',
 'L4/5 IT': '#2ca02c',
 }
# The 20 cell types shared by both datasets, in legend order
labelnames = ['Astro', 'Endo', 'L2/3 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6b', 'Lamp5', 'Micro', 'Oligo', 'OPC', 'Peri', 'Pvalb',
'PVM', 'Sncg', 'Sst', 'Vip', 'VLMC']

fig = plt.figure(figsize = (13,5))
plt.subplot(121)
sns.scatterplot(x = sc_data.obsm['latent_umap'][:, 0],
                y = sc_data.obsm['latent_umap'][:, 1], hue = sc_data.obs['cell_type'], s = 8, palette = color_dict,
                legend = False)
plt.title("snRNA-seq Latent")
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
plt.axis('off')

plt.subplot(122)
sns.scatterplot(x = st_data.obsm['latent_umap'][:, 0],
                y = st_data.obsm['latent_umap'][:, 1],  hue = st_data.obs['cell_type'], s = 8, palette = color_dict, legend = True)

legend = plt.legend(title = 'Cell Type', prop={'size': 12}, fontsize = '12',  markerscale = 3, ncol = 2, bbox_to_anchor = (1.2, 1))
plt.setp(legend.get_title(),fontsize='12')
plt.title("st_data Latent")
plt.axis('off')
plt.tight_layout()
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
# The legend is anchored outside the axes, so include it in the saved bounds
plt.savefig('./Result/mop/latent.pdf', bbox_inches='tight')
# plt.show()

Calculate the per-gene cosine similarity and Pearson correlation between measured and imputed expression¶

In [10]:
from sklearn.metrics.pairwise import cosine_similarity
import scipy.stats as st

# Per-gene agreement between measured MERFISH expression and the imputed
# expression, over the genes present in both (sorted for a reproducible order).
sp_adata = st_data.copy()
generated_cells = st_data.obsm['imputation']
overlaped_genes = np.array(sorted(set(sp_adata.var.index) & set(generated_cells.columns)))
sp_adata_SS = sp_adata[:, overlaped_genes].copy()
generated_cells = generated_cells.loc[:, overlaped_genes].T  # genes x cells

# Rows of both (genes x cells) matrices are genes, so the diagonal of the
# pairwise matrix is the cosine similarity of each gene with its own imputation.
temp = pd.DataFrame()
temp['cosine similarity'] = list(np.diag(cosine_similarity(sp_adata_SS.X.copy().T, generated_cells.values)))
cosine = temp['cosine similarity'].mean()
print('cosine similarity:', cosine)

# Pearson correlation (PCC) per gene; collected in a dict and turned into a
# one-row frame at the end instead of concatenating inside the loop.
raw = sp_adata_SS.to_df()
impute = generated_cells.T
pcc_by_gene = {}
for label in raw.columns:
    if label not in impute.columns:
        pcc_by_gene[label] = 0  # no imputed counterpart for this gene
        continue
    raw_col = raw.loc[:, label].fillna(1e-20)
    impute_col = impute.loc[:, label].fillna(1e-20)
    pcc_by_gene[label], _ = st.pearsonr(raw_col, impute_col)
result = pd.DataFrame(pcc_by_gene, index=["PCC"])

print(result.median(axis=1))
cosine similarity: 0.6716988
PCC    0.622068
dtype: float32

Save results¶

In [14]:
# Persist the trained outputs: the reference AnnData, the spatial AnnData
# (whose .obsm holds the latent and the imputed expression, reloaded by the
# downstream section), and a CSV copy of the imputation for the method comparison.
sc_data.write('./datasets/mop/ENVI_Ref_snRNA_mop_qc3_2Kgenes.h5ad')
st_data.write('./datasets/mop/ENVI_mop.h5ad')
st_data.obsm['imputation'].to_csv('./Result/mop/SpateCV_impute.csv', header=True, index=True)

Downstream analysis¶

In [2]:
# Reload the saved spatial AnnData; later cells read .obsm['imputation'] from it
st_data = sc.read("./datasets/mop/ENVI_mop.h5ad")
In [5]:
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

# Plot coordinates for the gene maps; y is negated so the tissue is drawn
# with y increasing downwards.
x = st_data.obsm['spatial'][:, 0]
y = -st_data.obsm['spatial'][:, 1]

# Marker panel = reference highly variable genes + genes of particular interest
sc_data = sc.read_h5ad('./datasets/mop/processed_snRNA_mop.h5ad')
markers = list(set(sc_data.var.index[sc_data.var['highly_variable'] == 1]))
add_genes = 'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'.split()

markers = np.unique(markers + add_genes)
print("Markers:", len(markers))

# Split genes into those measured by MERFISH and genes of interest that are not
st_genes = st_data.var_names.tolist()
merfish_genes = [g for g in markers if g in st_genes]
nonmerfish_genes = [g for g in add_genes if g not in merfish_genes]
print(len(merfish_genes), len(nonmerfish_genes))
Markers: 1938
247 13

Method comparison¶

We compare the accuracy of several methods in predicting the 1938 marker genes.

In [18]:
# SpateCV imputation comes from .obsm; the competing methods' results are CSV files.
# NOTE(review): read without index_col, so rows are matched to st_data by
# position downstream; presumably written in the same cell order — confirm.
df_SpateCV = st_data.obsm['imputation']
df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus = [
    pd.read_csv(path)
    for path in (
        './Result/mop/ENVI_impute.csv',
        './Result/mop/SpaGE_impute.csv',
        './Result/mop/Tangram_impute.csv',
        './Result/mop/gimVI_impute.csv',
        './Result/mop/stPlus_impute.csv',
    )
]

We selected seven marker genes with clear spatial expression patterns.

In [6]:
def plot_predictGene_comparison(predicted_genes,
                                df_SpateCV,
                                df_ENVI,
                                df_SpaGE,
                                df_Tangram,
                                df_gimVI,
                                df_stPlus,
                                st_data,
                                title="MERFISH test genes",
                                save_path='./Result/mop/predictGeneCompare.pdf'):
    """Plot spatial maps of measured vs. imputed expression, one row per gene.

    Columns are: gene name, ground truth (from ``st_data``), then SpateCV, ENVI,
    SpaGE, Tangram, gimVI and stPlus. Every panel shows log(expr + 0.1) on one
    shared colour scale (20th-95th percentile over all panels).

    Parameters
    ----------
    predicted_genes : list of str
        Genes to plot; each must be a variable of ``st_data`` and a column of
        every method DataFrame.
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus : DataFrame
        Imputed expression per method. Rows are matched to ``st_data`` cells
        by position, not by index label.
    st_data : AnnData
        Provides the ground-truth expression and ``obsm['spatial']`` coordinates.
    title : str
        Currently not drawn on the figure; kept for API compatibility.
    save_path : str
        Output path passed to ``plt.savefig``.
    """
    sns.set(style="white", context="paper", font_scale=2.5)

    method_names = ['Ground Truth', 'SpateCV', 'ENVI', 'SpaGE', 'Tangram', 'gimVI', 'stPlus']
    df_truth = st_data.to_df()[predicted_genes]  # (n_cells, n_genes)
    method_dfs = [df_truth, df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus]

    # log-transform each panel source once; each entry is (n_cells, n_genes)
    log_exprs = [np.log(df_[predicted_genes].values + 0.1) for df_ in method_dfs]

    # One colour range for every panel so methods are visually comparable
    combined_expression = np.concatenate(log_exprs, axis=0)
    vmin = np.percentile(combined_expression, 20)
    vmax = np.percentile(combined_expression, 95)

    # Coordinates are computed once so every panel uses the same layout
    x_coords = st_data.obsm['spatial'][:, 0]
    y_coords = -st_data.obsm['spatial'][:, 1]

    n_genes = len(predicted_genes)
    n_methods = len(method_names)  # ground truth + 6 imputation methods
    n_total_cols = n_methods + 1   # extra leading, narrower column for gene names

    fig = plt.figure(figsize=(2 * n_total_cols, 2 * n_genes), dpi=300)
    gs = GridSpec(n_genes, n_total_cols, figure=fig, wspace=0.1, hspace=0.01,
                  width_ratios=[1] + [4] * n_methods)

    for row_idx, gene in enumerate(predicted_genes):
        # Leftmost column: gene name
        name_ax = fig.add_subplot(gs[row_idx, 0])
        name_ax.text(0.5, 0.5, gene, fontsize=16, ha='center', va='center', transform=name_ax.transAxes)
        name_ax.axis('off')

        for col_idx, (method_name, expr) in enumerate(zip(method_names, log_exprs), start=1):
            ax = fig.add_subplot(gs[row_idx, col_idx])
            ax.scatter(
                x_coords,
                y_coords,
                c=expr[:, row_idx],
                cmap='Reds',
                vmin=vmin,
                vmax=vmax,
                s=5,
                edgecolor='none',
                alpha=0.8
            )
            if row_idx == 0:
                ax.set_title(method_name, fontsize=14, pad=10)
            ax.set_aspect('equal')
            ax.axis('off')

    plt.savefig(save_path)
    plt.show()

# Seven MERFISH-measured marker genes with clear spatial patterns
compare_gene = ['Osr1', 'Otof', 'Slc17a6', 'Fam84b', 'Opalin', 'Cdh12', 'Fezf2']
plot_predictGene_comparison(
    predicted_genes=compare_gene,
    df_SpateCV=df_SpateCV, df_ENVI=df_ENVI, df_SpaGE=df_SpaGE,
    df_Tangram=df_Tangram, df_gimVI=df_gimVI, df_stPlus=df_stPlus,
    st_data=st_data,
    title="MERFISH test genes",
)
No description has been provided for this image

Calculate the MAE between the predicted results and the true results¶

The `calculate_mae` function computes the MAE for the genes in `compare_gene` (or `merfish_genes`), and `plot_mae` draws the bar charts.

In [9]:
from sklearn.metrics import mean_absolute_error
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def calculate_mae(predicted_genes, df_truth, df_methods, methods):
    """Mean absolute error of each method's predictions against the ground truth.

    Rows are compared by position (as ``sklearn.metrics.mean_absolute_error``
    does), so prediction tables need not share ``df_truth``'s index.

    Parameters
    ----------
    predicted_genes : sequence of str
        Genes to score; each must be a column of ``df_truth`` and of every
        DataFrame in ``df_methods``.
    df_truth : DataFrame
        Measured expression, cells x genes.
    df_methods : sequence of DataFrame
        Predicted expression per method, in the same order as ``methods``.
    methods : sequence of str
        Method names; they become the result columns.

    Returns
    -------
    mae_results : DataFrame
        float64 MAE, genes as rows and methods as columns.
    average_mae : DataFrame
        Single column 'Average MAE': mean over genes for each method.

    Raises
    ------
    ValueError
        If ``methods`` and ``df_methods`` differ in length, a prediction has a
        different shape than the truth, or any value is NaN/infinite.
    """
    methods = list(methods)
    df_methods = list(df_methods)
    # zip() would silently drop unmatched methods, leaving NaN columns
    if len(methods) != len(df_methods):
        raise ValueError(
            f"methods ({len(methods)}) and df_methods ({len(df_methods)}) must have the same length"
        )

    genes = list(predicted_genes)
    truth = df_truth[genes].to_numpy(dtype=float)
    if not np.isfinite(truth).all():
        raise ValueError("df_truth contains NaN or infinite values")

    mae_by_method = {}
    for method, df_pred in zip(methods, df_methods):
        pred = df_pred[genes].to_numpy(dtype=float)
        if pred.shape != truth.shape:
            raise ValueError(f"{method}: prediction has shape {pred.shape}, expected {truth.shape}")
        if not np.isfinite(pred).all():
            raise ValueError(f"{method}: predictions contain NaN or infinite values")
        # Column-wise mean of |pred - truth| is the per-gene MAE
        mae_by_method[method] = np.abs(pred - truth).mean(axis=0)

    mae_results = pd.DataFrame(mae_by_method, index=genes, columns=methods, dtype=float)

    # Average MAE of each method across the genes
    average_mae = mae_results.mean().to_frame(name='Average MAE')

    return mae_results, average_mae

def plot_mae(mae_results, average_mae, methods):
    """Draw and save two bar charts of imputation error.

    1. Per-gene MAE grouped by method -> './Result/mop/MAE-Compare.pdf'
    2. Mean MAE per method          -> './Result/mop/Average-MAE-Compare.pdf'

    Note: calls ``sns.set(style="whitegrid", ...)`` before the second chart,
    which changes the global seaborn theme for later figures.
    """

    def _finish(tick_size, out_path, **save_kwargs):
        # Shared tail of both charts: tick styling, blank axis labels, despine, save, show.
        plt.xticks(rotation=0, ha='center', fontsize=tick_size)
        plt.yticks(rotation=0, ha='right', fontsize=tick_size)
        plt.xlabel("")
        plt.ylabel("")
        sns.despine()
        plt.tight_layout()
        plt.savefig(out_path, **save_kwargs)
        plt.show()

    # Long format: one row per (gene, method) pair
    per_gene = (
        mae_results.reset_index()
        .melt(id_vars='index', value_vars=methods, var_name='Method', value_name='MAE')
        .rename(columns={'index': 'Gene'})
    )

    plt.figure(figsize=(4, 3), dpi=300)
    sns.barplot(data=per_gene, x='Gene', y='MAE', hue='Method', palette='Set3')
    plt.title('MAE between Predicted Marker Genes and Ground Truth', fontsize=8)
    plt.legend(title='Method', fontsize=4, title_fontsize=6)
    _finish(8, './Result/mop/MAE-Compare.pdf')

    plt.figure(figsize=(4, 3), dpi=300)
    sns.set(style="whitegrid", context="paper", font_scale=1.1)
    sns.barplot(
        data=average_mae.reset_index(),
        x='index',
        y='Average MAE',
        palette='Set3',
        width=0.5
    )
    plt.title('Average MAE across Predicted Marker Genes', fontsize=10, pad=12)
    _finish(10, './Result/mop/Average-MAE-Compare.pdf', bbox_inches='tight')
    

# Method labels paired with their imputation tables (order sets the bar order)
methods, df_methods = map(list, zip(
    ('SpateCV', df_SpateCV),
    ('ENVI', df_ENVI),
    ('stPlus', df_stPlus),
    ('SpaGE', df_SpaGE),
    ('Tangram', df_Tangram),
    ('gimVI', df_gimVI),
))
df_truth = st_data.to_df()[merfish_genes]  # measured expression, (n_cells, n_genes)

mae_results, average_mae = calculate_mae(compare_gene, df_truth, df_methods, methods)
plot_mae(mae_results, average_mae, methods)
No description has been provided for this image
No description has been provided for this image

Imputation of Non-MERFISH genes¶

In [10]:
def plot_predictNonGene_comparison(df_SpateCV, df_ENVI,df_SpaGE, df_Tangram, df_gimVI, df_stPlus, genes, 
                                   x, y, title='Non-MERFISH Gene Expression',
                                   save_path='./Result/mop/Astro_makerGene.pdf'):
    """Plot spatial maps of imputed expression for ``genes``, one column per method.

    Rows are genes. The first column shows the gene name; the remaining columns
    show log(expr + 0.1) from SpateCV, ENVI, SpaGE, Tangram, gimVI and stPlus.
    All panels share one colour scale (20th-95th percentile over every method
    and gene).

    Parameters
    ----------
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus : DataFrame
        Imputed expression per method, genes as columns; rows are matched to
        ``x``/``y`` by position.
    genes : list of str
        Genes to plot.
    x, y : array-like
        Plot coordinates of each cell.
    title : str
        Currently not drawn on the figure; kept for API compatibility.
    save_path : str
        Output path passed to ``plt.savefig``.

    Raises
    ------
    ValueError
        If any method's DataFrame lacks one of ``genes``.
    """
    sns.set(style="white", context="paper", font_scale=2.5)

    method_names = ['SpateCV', 'ENVI', 'SpaGE', 'Tangram', 'gimVI', 'stPlus']
    method_dfs = [df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus]

    # Validate and log-transform every plotted method (ENVI included) so the
    # shared colour scale covers all panels and missing genes fail fast.
    log_exprs = []
    for df_ in method_dfs:
        missing_genes = set(genes) - set(df_.columns)
        if missing_genes:
            raise ValueError(f"The following genes do not exist in the DataFrame: {missing_genes}")
        log_exprs.append(np.log(df_[genes].values + 0.1))  # (n_cells, n_genes)

    all_expr_values = np.concatenate([expr.ravel() for expr in log_exprs])
    vmin = np.percentile(all_expr_values, 20)
    vmax = np.percentile(all_expr_values, 95)

    n_genes = len(genes)
    n_methods = len(method_names)
    n_total_cols = n_methods + 1  # extra leading column for gene names

    fig = plt.figure(figsize=(2*n_total_cols, 2*n_genes), dpi=300)
    gs = GridSpec(n_genes, n_total_cols, figure=fig, wspace=0.1, hspace=0.01, width_ratios=[1] + [4]*n_methods)

    for row_idx, gene in enumerate(genes):
        name_ax = fig.add_subplot(gs[row_idx, 0])
        name_ax.text(0.95, 0.5, gene, fontsize=16, ha='right', va='center', transform=name_ax.transAxes)
        name_ax.axis('off')

        for col_idx, (method_name, expr) in enumerate(zip(method_names, log_exprs), start=1):
            ax = fig.add_subplot(gs[row_idx, col_idx])
            ax.scatter(
                x=x,
                y=y,
                c=expr[:, row_idx],
                cmap='Reds',
                vmin=vmin,
                vmax=vmax,
                s=5,
                edgecolor='none',
                alpha=0.8
            )
            if row_idx == 0:
                ax.set_title(method_name, fontsize=14, pad=15)
            ax.set_aspect('equal')
            ax.axis('off')

    plt.savefig(save_path)
    plt.show()

# Genes of interest that are not on the MERFISH panel (not used by the plot below)
None_gene = [g for g in add_genes if g not in st_genes]
genes = ['Gfap', 'Aqp4', 'Nfia', 'Hepacam', 'Nxn', 'Ptprz1', 'Gramd3']  # Astro DE genes
plot_predictNonGene_comparison(
    df_SpateCV=df_SpateCV, df_ENVI=df_ENVI, df_SpaGE=df_SpaGE,
    df_Tangram=df_Tangram, df_gimVI=df_gimVI, df_stPlus=df_stPlus,
    genes=genes, x=x, y=y,
    title='Non-MERFISH Gene Expression',
)
No description has been provided for this image

All cell types

In [63]:
# Spatial map of all cell types, coloured by color_dict and ordered by labelnames.
# NOTE(review): both names are defined in the latent-UMAP cell, which must have
# been run first.
sns.set_context('paper',font_scale=2) 
plt.subplots(figsize=(5,5),dpi=300)
sns.scatterplot(data=st_data.obs, x="x", y="y", hue="cell_type",hue_order=labelnames,s=15,palette=color_dict)
plt.legend(title = 'Cell Type', prop={'size': 12}, fontsize = '12',  markerscale = 3, ncol = 2, bbox_to_anchor = (1, 1))
# plt.legend(bbox_to_anchor=(1.0,0.98), loc="upper left",framealpha=0,markerscale=1.5)
plt.gca().invert_yaxis()  # y increases downwards, like the -y gene maps
plt.axis('off')
plt.savefig('./Result/mop/MOp-Cell_type.pdf', bbox_inches='tight')
plt.show()
No description has been provided for this image
In [19]:
# Spatial map of the neuron subclasses listed in ex_neuronal; every other cell
# is relabelled 'None'.
sns.set_context('paper',font_scale=4.5) 
plt.subplots(figsize=(5,5),dpi=300)
ex_neuronal = ['L2/3 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6b']
p2_res = st_data.obs.copy()
# If 'cell_type' is categorical, register 'None' as a category so it can be assigned below
if p2_res['cell_type'].dtype.name == 'category':
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(ex_neuronal), 'cell_type'] = 'None'
# NOTE(review): 'None' is not in hue_order, so those cells may not be drawn with
# color_dict['None'] — confirm against the installed seaborn version.
sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type",hue_order=ex_neuronal,s=15,palette=color_dict)
plt.legend(bbox_to_anchor=(1,0.72), loc="upper left",fontsize = '12',framealpha=0,markerscale=3)
plt.gca().invert_yaxis()
plt.axis('off')
plt.savefig('./Result/mop/Celltype-layer.pdf', bbox_inches='tight')
plt.show()
No description has been provided for this image
In [22]:
# Spatial map of the non-neuronal types listed in non_neuronal; every other
# cell is relabelled 'None'. (Saved as 'Astro-celltype.pdf' despite covering
# all of these types.)
sns.set_context('paper',font_scale=4.5) 
plt.subplots(figsize=(5,5),dpi=300)
non_neuronal = ['Astro', 'Endo', 'Micro','OPC','Peri','PVM','VLMC','Oligo']
p2_res = st_data.obs.copy()
# If 'cell_type' is categorical, register 'None' as a category so it can be assigned below
if p2_res['cell_type'].dtype.name == 'category':
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(non_neuronal),'cell_type']='None'

sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type",hue_order=non_neuronal,s=15,palette=color_dict)
plt.legend(bbox_to_anchor=(1,0.72), loc="upper left",fontsize = '12',framealpha=0,markerscale=3)
plt.gca().invert_yaxis()
plt.axis('off')
plt.savefig('./Result/mop/Astro-celltype.pdf', bbox_inches='tight')
plt.show()
No description has been provided for this image
In [23]:
# Lighter grey for 'None'. This mutates the shared palette, so it also affects
# any plot that runs later with color_dict.
color_dict['None'] = '#e3e1e1'

# Spatial map of the neuron subclasses listed in inh_neuronal; every other cell
# is relabelled 'None'. Unlike the two cells above, this figure is not saved.
sns.set_context('paper',font_scale=4.5) 
plt.subplots(figsize=(5,5),dpi=150)
inh_neuronal = ['Lamp5', 'Sncg', 'Vip', 'Sst', 'Pvalb']
p2_res = st_data.obs.copy()
# If 'cell_type' is categorical, register 'None' as a category so it can be assigned below
if p2_res['cell_type'].dtype.name == 'category':
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(inh_neuronal),'cell_type']='None'

sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type",hue_order=inh_neuronal,s=15,palette=color_dict)
plt.legend(bbox_to_anchor=(.92,0.82), loc="upper left",framealpha=0,markerscale=4.5)
plt.gca().invert_yaxis()
plt.axis('off')
plt.show()
No description has been provided for this image